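"""
Generate one image per selected voxel from its previously saved caption,
using a text-to-image model (stable-diffusion-2-1-base or FLUX.1-schnell).

Example invocations:

[Image]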
python -m evaluation.caption2image \
    --subject_names subj01 subj02 subj05 subj07 \
    --atlasname streams floc-faces floc-places floc-bodies floc-words \
    --modality image \
    --modality_hparam default \
    --model_name CLIP-ViT-B-32 \
    --reduce_dims default 0 \
    --dataset_name OpenImages \
    --max_samples full \
    --dataset_path ./data/OpenImages/frames_518x518px \
    --dataset_captioner MiniCPM-Llama3-V-2_5 \
    --voxel_selection pvalues_corrected 0.05 \
    --layer_selection best \
    --caption_model MeaCap \
    --keywords_model gpt-4o-2024-08-06 \
    --correct_model default \
    --candidate_num 50 \
    --key_num 5 \
    --temperature 0.05 \
    --filter_th 0.15 \
    --cc_method spearman \
    --img_gen_model FLUX.1-schnell \
    --device cuda:0

## For BrainSCUBA
python -m evaluation.caption2image \
    --subject_names subj01 subj02 subj05 subj07 \
    --atlasname streams floc-faces floc-places floc-bodies floc-words \
    --betas_norm \
    --modality image \
    --modality_hparam default \
    --model_name CLIP-ViT-B-32 \
    --reduce_dims default 0 \
    --dataset_name OpenImages \
    --max_samples full \
    --dataset_path ./data/OpenImages/frames_518x518px \
    --dataset_captioner MiniCPM-Llama3-V-2_5 \
    --voxel_selection pvalues_corrected 0.05 \
    --layer_selection best \
    --caption_model brainscuba \
    --keywords_model default \
    --correct_model default \
    --candidate_num 1 \
    --key_num -1 \
    --temperature -1 \
    --filter_th -1 \
    --tau 150 \
    --cc_method spearman \
    --img_gen_model FLUX.1-schnell \
    --device cuda:0
"""

import torch
import argparse
import os
import json
from tqdm import tqdm
from utils.utils import (
    make_filename, search_best_layer, TrnVal, 
    collect_fmri_byroi_for_nsd, create_volume_index_and_weight_map
)
import numpy as np
from utils.nsd_access import NSDAccess
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, FluxPipeline

# nltk.download('punkt')
# NOTE(review): the line above is commented out and nltk is not imported in
# the visible part of this file -- possibly a leftover; confirm before removing.

# Seed PyTorch's global RNG when the module is loaded.
torch.manual_seed(42)

def load_resp_wholevoxels_for_nsd(subject_name, dataset="all", atlas="streams"):
    """Collect the TRAIN and VALID responses of one subject into a ``TrnVal``.

    Args:
        subject_name: Subject identifier forwarded to
            ``collect_fmri_byroi_for_nsd``.
        dataset: Accepted for interface compatibility; not used here.
        atlas: Atlas name forwarded as ``atlasname``.

    Returns:
        ``TrnVal`` whose ``trn`` / ``val`` fields hold the TRAIN / VALID
        results of ``collect_fmri_byroi_for_nsd``.
    """
    # Dict comprehension keeps the original call order: TRAIN first, then VALID.
    by_split = {
        split: collect_fmri_byroi_for_nsd(
            subject_name, trainvalid=split, atlasname=atlas
        )
        for split in ("TRAIN", "VALID")
    }
    return TrnVal(trn=by_split["TRAIN"], val=by_split["VALID"])

    
_SUPPORTED_IMG_GEN_MODELS = ("stable-diffusion-2-1-base", "FLUX.1-schnell")


def _load_img_gen_pipeline(genmodel_name, device):
    """Load the text-to-image pipeline and its per-call sampling options.

    Args:
        genmodel_name: One of ``_SUPPORTED_IMG_GEN_MODELS``.
        device: Torch device string the pipeline is moved to.

    Returns:
        ``(pipe, sampling_kwargs)`` where ``sampling_kwargs`` are the
        model-specific keyword arguments to pass on every ``pipe(...)`` call.

    Raises:
        ValueError: If ``genmodel_name`` is not supported.
    """
    if genmodel_name == "stable-diffusion-2-1-base":
        pipe = StableDiffusionPipeline.from_pretrained(
            f"stabilityai/{genmodel_name}", torch_dtype=torch.float16
        )
        print(pipe.scheduler.config)
        # Use DPM-Solver++ (2nd order). Options are passed through from_config
        # rather than assigned onto the config after construction.
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config, algorithm_type="dpmsolver++", solver_order=2
        )
        # Honour the requested device (previously hard-coded to "cuda").
        pipe = pipe.to(device)
        return pipe, {"num_inference_steps": 50}

    if genmodel_name == "FLUX.1-schnell":
        pipe = FluxPipeline.from_pretrained(
            f"black-forest-labs/{genmodel_name}", torch_dtype=torch.bfloat16
        ).to(device)
        sampling_kwargs = {
            "guidance_scale": 0.0,
            "num_inference_steps": 4,
            "max_sequence_length": 256,
        }
        return pipe, sampling_kwargs

    raise ValueError(
        f"Unsupported img_gen_model: {genmodel_name!r}. "
        f"Supported models: {', '.join(_SUPPORTED_IMG_GEN_MODELS)}"
    )


def _load_reducer_projector(projector_dir, subject_name, filename):
    """Load the saved projector, trying the ``.npy`` file first, then ``.pkl``.

    NOTE(review): both loads use ``allow_pickle=True``; only point this at
    trusted, locally produced files.
    """
    stem = f"{projector_dir}/projector_{subject_name}_ave_{filename}"
    try:
        return np.load(f"{stem}.npy", allow_pickle=True).item()
    except Exception:
        # Best-effort fallback kept from the original code: any failure on the
        # .npy path (missing file, unexpected content) falls through to .pkl.
        return np.load(f"{stem}.pkl", allow_pickle=True)


def _read_caption_and_image_name(args, resp_save_path, size_tag):
    """Read the voxel's caption and derive the output image file name.

    Args:
        args: Parsed CLI namespace (caption-related options are read).
        resp_save_path: Directory holding the voxel's caption file.
        size_tag: Prefix that replaces the caption file's leading tag in the
            image file name (``"<model>_<height>x<width>px"``).

    Returns:
        ``(caption, image_base_name)``.
    """
    if args.caption_model == "brainscuba":
        norm_suffix = "_betanorm" if args.betas_norm else ""
        caption_file_path = os.path.join(
            resp_save_path,
            f"caption_{args.caption_model}_tau{args.tau}{norm_suffix}.txt",
        )
        with open(caption_file_path, "r") as f:
            caption = f.read()
        image_base_name = (
            os.path.basename(caption_file_path)
            .replace("caption", size_tag)
            .replace(".txt", ".png")
        )
        return caption, image_base_name

    keys_and_text_file_path = os.path.join(
        resp_save_path,
        f"keys_and_text_{args.caption_model}_kmodel_{args.keywords_model}_{args.key_num}keys_{args.temperature}temp_{args.filter_th}th_{args.candidate_num}cands_cmodel_{args.correct_model}.json",
    )
    with open(keys_and_text_file_path, "r") as f:
        keys_and_text = json.load(f)
    image_base_name = (
        os.path.basename(keys_and_text_file_path)
        .replace("keys_and_text", size_tag)
        .replace(".json", ".png")
    )
    return keys_and_text["text"], image_base_name


def main(args):
    """Generate and save one image per selected voxel from its saved caption.

    For every subject in ``args.subject_names`` the voxels returned by
    ``create_volume_index_and_weight_map`` are visited; for each voxel the
    previously saved caption is read from the voxel's in-silico directory,
    passed to the text-to-image pipeline, and the resulting 512x512 image is
    written to ``<voxel dir>/gen_images``. A per-voxel lock file
    (``temp_gen_image_<caption_model>.txt``) lets several workers share the
    same voxel list; voxels whose lock or output image already exists are
    skipped.

    Args:
        args: Parsed CLI namespace. Fields read: ``subject_names``,
            ``atlasname``, ``betas_norm``, ``modality``, ``modality_hparam``,
            ``model_name``, ``reduce_dims``, ``dataset_name``, ``max_samples``,
            ``voxel_selection``, ``layer_selection``, ``caption_model``,
            ``keywords_model``, ``correct_model``, ``candidate_num``,
            ``key_num``, ``temperature``, ``filter_th``, ``tau``,
            ``img_gen_model`` and ``device``.

    Raises:
        ValueError: If ``args.img_gen_model`` is not a supported model.
    """
    genmodel_name = args.img_gen_model
    # Fail fast, before any model or data loading, instead of hitting an
    # undefined-variable error at the first voxel.
    if genmodel_name not in _SUPPORTED_IMG_GEN_MODELS:
        raise ValueError(
            f"Unsupported img_gen_model: {genmodel_name!r}. "
            f"Supported models: {', '.join(_SUPPORTED_IMG_GEN_MODELS)}"
        )

    score_root_path = "./data/nsd/encoding"
    stim_root_path = "./data/stim_features/nsd"
    modality = args.modality
    modality_hparam = args.modality_hparam
    model_name = args.model_name
    file_type = args.voxel_selection[0]
    threshold = float(args.voxel_selection[1])
    nsda = NSDAccess('./data/NSD')

    pipe, sampling_kwargs = _load_img_gen_pipeline(genmodel_name, args.device)

    width = 512
    height = 512
    size_tag = f"{genmodel_name}_{height}x{width}px"
    # Fix the seed. The same generator object is reused for every image, so
    # its state advances from one generated image to the next.
    seed = 42
    generator = torch.manual_seed(seed)

    for subject_name in args.subject_names:
        print(subject_name)
        filename = make_filename(args.reduce_dims[0:2])

        print(f"Modality: {modality}, Modality hparams: {modality_hparam}, Feature: {model_name}, Filename: {filename}")
        # Select the layer for this subject.
        model_score_dir = f"{score_root_path}/{subject_name}/scores/{modality}/{modality_hparam}/{model_name}"
        if args.layer_selection == "best":
            target_best_cv_layer, _, _ = search_best_layer(model_score_dir, filename, select_topN="all")
        else:
            target_best_cv_layer = args.layer_selection
        print(f"Best layer: {target_best_cv_layer}")

        volume_index, _, _ = create_volume_index_and_weight_map(
            subject_name=subject_name,
            file_type=file_type,
            threshold=threshold,
            model_score_dir=model_score_dir,
            target_best_cv_layer=target_best_cv_layer,
            filename=filename,
            nsda=nsda,
            atlasnames=args.atlasname  # args.atlasname is expected to be a list
        )

        if args.reduce_dims[0] != "default":
            projector_dir = f"{stim_root_path}/{modality}/{modality_hparam}/{model_name}/{target_best_cv_layer}"
            reducer_projector = _load_reducer_projector(projector_dir, subject_name, filename)
        else:
            reducer_projector = None
        print(reducer_projector)

        for voxel_index in tqdm(volume_index):
            print(f"voxel_index: {voxel_index}")
            vindex_pad = str(voxel_index).zfill(5)
            resp_save_path = f"./data/nsd/insilico/{subject_name}/{args.dataset_name}_{args.max_samples}/{modality}/{modality_hparam}/{model_name}_{filename}/whole/voxel{vindex_pad}"

            temp_file_path = f"{resp_save_path}/temp_gen_image_{args.caption_model}.txt"
            # Acquire the per-voxel lock with exclusive creation, OUTSIDE the
            # try/finally below: a lock held by another worker must be neither
            # raced against nor deleted by this worker's cleanup.
            try:
                with open(temp_file_path, "x"):
                    pass
            except FileExistsError:
                print(f"Simulation for {voxel_index} is being processed.")
                continue

            print(f"Now processing: {voxel_index}")
            try:
                image_save_dir = f"{resp_save_path}/gen_images"
                os.makedirs(image_save_dir, exist_ok=True)

                caption, image_base_name = _read_caption_and_image_name(
                    args, resp_save_path, size_tag
                )
                print(caption)

                image_save_path = f"{image_save_dir}/{image_base_name}"
                if os.path.exists(image_save_path):
                    print("Already processed.")
                    continue

                # Generate the image.
                image = pipe(
                    caption,
                    height=height,
                    width=width,
                    generator=generator,
                    **sampling_kwargs,
                ).images[0]

                # Save the image.
                image.save(image_save_path)
            finally:
                # Release only the lock this worker created above.
                try:
                    os.remove(temp_file_path)
                except OSError as err:
                    print(f"Could not remove lock file {temp_file_path}: {err}")

if __name__ == "__main__":
    # Command-line options, declared as (flag, add_argument keyword arguments)
    # pairs and registered in this order.
    cli_options = [
        ("--subject_names", {"nargs": "*", "type": str, "required": True}),
        ("--atlasname", {"nargs": "*", "type": str, "required": True}),
        ("--betas_norm", {"action": "store_true"}),
        ("--modality", {
            "type": str,
            "required": True,
            "help": "Name of the modality to use.",
        }),
        ("--modality_hparam", {"type": str, "required": True}),
        ("--model_name", {"type": str, "required": True}),
        ("--reduce_dims", {"nargs": "*", "type": str, "required": True}),
        ("--dataset_name", {"type": str, "required": True}),
        ("--max_samples", {"type": str, "required": True}),
        ("--dataset_path", {"type": str, "required": True}),
        ("--dataset_captioner", {"type": str, "required": True}),
        ("--voxel_selection", {
            "nargs": "*",
            "type": str,
            "required": True,
            "help": "Selection method of voxels. Implemented type are 'uv' and 'share'.",
        }),
        ("--layer_selection", {"type": str, "required": False, "default": "best"}),
        ("--caption_model", {
            "type": str,
            "required": True,
            "help": "Name of the captioning model to use.",
        }),
        ("--keywords_model", {"type": str, "required": False}),
        ("--correct_model", {
            "type": str,
            "required": True,
            "choices": ["None", "default", "gpt-4o-2024-08-06"],
            "help": "Name of the correction model to use.",
        }),
        ("--candidate_num", {"type": int, "required": True}),
        ("--key_num", {"type": int, "required": True}),
        ("--temperature", {"type": float, "required": True}),
        ("--filter_th", {"type": float, "required": False}),
        ("--cc_method", {
            "type": str,
            "required": True,
            "choices": ["spearman", "pearson"],
        }),
        ("--tau", {"type": float, "required": False}),
        ("--img_gen_model", {"type": str, "required": True}),
        ("--device", {"type": str, "required": True, "help": "Device to use."}),
        ("--embs_only", {
            "action": "store_true",
            "required": False,
            "default": False,
        }),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    main(args)